{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 5.2 FETI-DP in NGSolve I: Working with Point-Constraints\n",
    "\n",
    "We implement standard FETI-DP, using only point-constraints, for Poisson's equation, in 2D.\n",
    "\n",
    "This is easily done with NGS-Py, because:\n",
    "\n",
    " - We need access to sub-assembled matrices for each subdomain, which is exactly how matrices are stored in NGSolve.\n",
    " - In fact, we can just reuse these local matrices for the dual-primal space.\n",
    " - We do not need to concern ourselfs with finding any sort of global enumeration in the dual-primal space, because\n",
    "   NGSolve does not work with one. We only need ParallelDofs for the dual-primal space.\n",
    "   \n",
    "## Use the cells at the bottom of this file to experiment with bigger computations!\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_procs = '40'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from usrmeeting_jupyterstuff import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stop_cluster()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Waiting for connection file: ~/.ipython/profile_ngsolve/security/ipcontroller-kogler-client.json\n",
      "connecting ... try:6 succeeded!"
     ]
    }
   ],
   "source": [
    "start_cluster(num_procs)\n",
    "connect_cluster()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[stdout:10] \n",
      "global,  ndof = 187833 , lodofs = 47161\n",
      "avg DOFs per core:  4823.075\n"
     ]
    }
   ],
   "source": [
    "%%px\n",
    "from ngsolve import *\n",
    "import netgen.meshing as ngmeshing\n",
    "\n",
    "comm = MPI_Init()\n",
    "nref = 1\n",
    "\n",
    "ngmesh = ngmeshing.Mesh(dim=2)\n",
    "ngmesh.Load('squaref.vol')\n",
    "for l in range(nref):\n",
    "    ngmesh.Refine()\n",
    "mesh = Mesh(ngmesh)\n",
    "\n",
    "comm = MPI_Init()\n",
    "fes = H1(mesh, order=2, dirichlet='right|top')\n",
    "a = BilinearForm(fes)\n",
    "u,v = fes.TnT()\n",
    "a += SymbolicBFI(grad(u)*grad(v))\n",
    "a.Assemble()\n",
    "f = LinearForm(fes)\n",
    "f += SymbolicLFI(x*y*v)\n",
    "f.Assemble()\n",
    "avg_dof = comm.Sum(fes.ndof) / comm.size\n",
    "if comm.rank==0:\n",
    "    print('global,  ndof =', fes.ndofglobal, ', lodofs =', fes.lospace.ndofglobal)\n",
    "    print('avg DOFs per core: ', avg_dof)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Finding the primal DOFs\n",
    "\n",
    "First, we have to find out which DOFs sit on subdomain vertices. \n",
    "\n",
    "In 2 dimensions, those are just the ones that are shared between 3 or more ranks.\n",
    "\n",
    "Additionally, we only need to concern ourselfs with non-dirichlet DOFs.\n",
    "\n",
    "We simply construct a BitArray, which is True wherever we have a primal DOF, and else False."
   ]
  },
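  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a toy illustration of this classification, the following sketch sorts a handful of DOFs by the number of *other* ranks sharing them.\n",
    "The table `dof2proc` and the flags in `free` are made up for this example; they merely stand in for\n",
    "`fes.ParallelDofs().Dof2Proc(k)` and `fes.FreeDofs()`. The thresholds mirror the tests `==0` and `>1` used in this notebook.\n",
    "\n",
    "```python\n",
    "# Hypothetical sharing table for 6 local DOFs: dof2proc[k] lists the\n",
    "# *other* ranks that also hold DOF k (a stand-in for Dof2Proc(k)).\n",
    "dof2proc = [[], [], [2], [2], [2, 5], [2, 5, 7]]\n",
    "# Hypothetical Dirichlet flags: DOF 5 sits on the Dirichlet boundary.\n",
    "free = [True, True, True, True, True, False]\n",
    "\n",
    "# inner: not shared at all; primal: shared with more than one other rank\n",
    "inner = [len(p) == 0 and fr for p, fr in zip(dof2proc, free)]\n",
    "primal = [len(p) > 1 and fr for p, fr in zip(dof2proc, free)]\n",
    "# dual (here: free interface DOFs that are not primal)\n",
    "dual = [fr and not pr and not inn for fr, pr, inn in zip(free, primal, inner)]\n",
    "\n",
    "print('inner :', [k for k, flag in enumerate(inner) if flag])\n",
    "print('dual  :', [k for k, flag in enumerate(dual) if flag])\n",
    "print('primal:', [k for k, flag in enumerate(primal) if flag])\n",
    "```\n",
    "\n",
    "DOF 5 is shared by four ranks in this toy table, but it is a Dirichlet DOF and therefore ends up in none of the three sets."
   ]
  },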
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[stdout:10] \n",
      "# primal dofs global:  56\n",
      "min, avg, max per rank:  1 4.225 7\n"
     ]
    }
   ],
   "source": [
    "%%px\n",
    "primal_dofs = BitArray([len(fes.ParallelDofs().Dof2Proc(k))>1 for k in range(fes.ndof)]) & fes.FreeDofs()\n",
    "nprim = comm.Sum(sum([1 for k in range(fes.ndof) if primal_dofs[k] \\\n",
    "                      and comm.rank<fes.ParallelDofs().Dof2Proc(k)[0] ]))\n",
    "npmin = comm.Min(primal_dofs.NumSet() if comm.rank else nprim)\n",
    "npavg = comm.Sum(primal_dofs.NumSet())/comm.size\n",
    "npmax = comm.Max(primal_dofs.NumSet())\n",
    "if comm.rank==0:\n",
    "    print('# primal dofs global: ', nprim)  \n",
    "    print('min, avg, max per rank: ', npmin, npavg, npmax)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setting up the Dual-Primal-Space Matrix\n",
    "\n",
    "We want to assemble the stiffness-matrix on the dual-primal space\n",
    "\n",
    "$$\n",
    "V_{\\scriptscriptstyle DP} = \\{u\\in W : u \\text{ is continuous at subdomain vertices} \\}\n",
    "$$\n",
    "\n",
    "This means that we need:\n",
    "\n",
    " - local stiffness matrices:\n",
    " - ParallelDofs\n",
    "\n",
    "If we restrict $V_{\\scriptscriptstyle DP}$ to some subdomain, we get the same space\n",
    "we get from restricting $V$, so the local matrices are the same!\n",
    "\n",
    "The coupling of DOFs between the subdomains is now different. We need to construct new ParallelDofs\n",
    "\n",
    "![](tix_fetidp_standalone.png)\n",
    "\n",
    "If we think of ParallelDofs as the connection between DOFs of different subdomains,\n",
    "the dual-primal ParallelDofs are simply a subset of the ones in the primal space."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[stdout:10] \n",
      "global ndof in original space   :  187833\n",
      "global ndof in dual-primal space:  192810\n"
     ]
    }
   ],
   "source": [
    "%%px\n",
    "dp_pardofs = fes.ParallelDofs().SubSet(primal_dofs)\n",
    "if comm.rank==0:\n",
    "    print('global ndof in original space   : ', fes.ParallelDofs().ndofglobal)\n",
    "    print('global ndof in dual-primal space: ', dp_pardofs.ndofglobal)    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### The ParallelMatrix for the DP-Space\n",
    "\n",
    "A ParallelMatrix is just a local matrix + ParallelDofs.\n",
    "We have access to both."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%px\n",
    "from ngsolve.la import ParallelMatrix\n",
    "A_dp = ParallelMatrix(a.mat.local_mat, dp_pardofs)\n",
    "A_dp_inv = A_dp.Inverse(fes.FreeDofs(), inverse='mumps')  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### The Jump-Operator $B$\n",
    "\n",
    "The Jump-operator is implemented on C++-side. The constraints are fully redundant.\n",
    "The jump operator builds constraints for **all** connections in the\n",
    "given ParallelDofs. In our case, we only need the constraints for the non-primal DOFs.\n",
    "\n",
    "In order for CreateRowVector to be properly defined,\n",
    "the jump-operator has to be told in which space the \"u-vectors\" live. \n",
    "\n",
    "This is done by setting \"u_pardofs\". If none are given, or they are explicitely \n",
    "given as \"None\", B assumes a fully discontinuous u-space."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[stdout:1] rank 8 , # of local multipliers =  283\n",
      "[stdout:2] rank 6 , # of local multipliers =  186\n",
      "[stdout:4] rank 7 , # of local multipliers =  214\n",
      "[stdout:7] rank 5 , # of local multipliers =  170\n",
      "[stdout:9] rank 4 , # of local multipliers =  279\n",
      "[stdout:10] # of global multipliers = : 4968\n",
      "[stdout:12] rank 1 , # of local multipliers =  259\n",
      "[stdout:14] rank 2 , # of local multipliers =  204\n",
      "[stdout:15] rank 3 , # of local multipliers =  293\n",
      "[stdout:17] rank 9 , # of local multipliers =  287\n"
     ]
    }
   ],
   "source": [
    "%%px\n",
    "from ngsolve.la import FETI_Jump\n",
    "dual_pardofs = fes.ParallelDofs().SubSet(BitArray(~primal_dofs & fes.FreeDofs()))\n",
    "B = FETI_Jump(dual_pardofs, u_pardofs=dp_pardofs)\n",
    "if comm.rank==0:\n",
    "    print('# of global multipliers = :', B.col_pardofs.ndofglobal)\n",
    "elif comm.rank<10:\n",
    "    print('rank', comm.rank, ', # of local multipliers = ', B.col_pardofs.ndoflocal)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Forming the Schur-Complement\n",
    "\n",
    "We can very easily define the (negative) Schur-Complement\n",
    "$F = B~A_{\\scriptscriptstyle DP}^{-1}~B^T$\n",
    "with the \"@\"-operator.\n",
    "\n",
    "\"@\" can concatenate any operators of \"BaseMatrix\" type.\n",
    "They have to provide a Mult method as well as CreateRow/ColVector,\n",
    "which is needed to construct vectors to store intermediate vectors.\n"
   ]
  },
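  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a reminder of where $F$ comes from, here is the standard elimination step, written with the symbols of this notebook\n",
    "(calling the right-hand side vector $f$ is an assumption about notation). The constrained problem reads\n",
    "\n",
    "$$\n",
    "\\left(\\begin{matrix}\n",
    "A_{\\scriptscriptstyle DP} & B^T \\\\\n",
    "B & 0\n",
    "\\end{matrix}\\right)\n",
    "\\left(\\begin{matrix}\n",
    "u \\\\\n",
    "\\lambda\n",
    "\\end{matrix}\\right)\n",
    "=\n",
    "\\left(\\begin{matrix}\n",
    "f \\\\\n",
    "0\n",
    "\\end{matrix}\\right)\n",
    "$$\n",
    "\n",
    "The first row gives $u = A_{\\scriptscriptstyle DP}^{-1}(f - B^T\\lambda)$. Inserting this into $Bu = 0$ yields\n",
    "\n",
    "$$\n",
    "B~A_{\\scriptscriptstyle DP}^{-1}~B^T \\lambda = B~A_{\\scriptscriptstyle DP}^{-1} f,\n",
    "\\qquad\\text{i.e.}\\qquad F\\lambda = B~A_{\\scriptscriptstyle DP}^{-1} f\n",
    "$$\n",
    "\n",
    "so only a system for $\\lambda$ has to be solved iteratively, and $u$ is recovered afterwards from the first relation."
   ]
  },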
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[stdout:10] type F:  <class 'ngsolve.la.BaseMatrix'>\n"
     ]
    }
   ],
   "source": [
    "%%px \n",
    "F = B @ A_dp_inv @ B.T\n",
    "if comm.rank==0:\n",
    "    print('type F: ', type(F))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### The Preconditioner"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The Dirichlet-Preconditioner can also be written down very easily.\n",
    "\n",
    "We have, with B standing for border DOFs, and I standing for inner DOFs,\n",
    "$$\n",
    "M^{-1} = B ~ (A_{BB} - A_{BI}A_{II}^{-1}A_{IB}) ~ B^T\n",
    "$$\n",
    "\n",
    "The inner part, the Dirichlet-to-Neumann Schur-Complement, can \n",
    "be implemented with a small trick:\n",
    "\n",
    "\\begin{align*}\n",
    "\\left(\\begin{matrix}\n",
    "A_{BB} & A_{BI} \\\\\n",
    "A_{IB} & A_{II}\n",
    "\\end{matrix}\\right)\n",
    "\\cdot\n",
    "\\left[\n",
    "\\left(\\begin{matrix}\n",
    "I & 0\\\\\n",
    "0 & I\n",
    "\\end{matrix}\\right)\n",
    "-\n",
    "\\left(\\begin{matrix}\n",
    "0 & 0\\\\\n",
    "0 & A_{II}^{-1}\n",
    "\\end{matrix}\\right)\n",
    "\\cdot\n",
    "\\left(\\begin{matrix}\n",
    "A_{BB} & A_{BI} \\\\\n",
    "A_{IB} & A_{II}\n",
    "\\end{matrix}\\right)\n",
    "\\right] \n",
    "&\n",
    "= \n",
    "\\left(\\begin{matrix}\n",
    "A_{BB} & A_{BI} \\\\\n",
    "A_{IB} & A_{II}\n",
    "\\end{matrix}\\right)\n",
    "\\cdot\n",
    "\\left[\n",
    "\\left(\\begin{matrix}\n",
    "I & 0\\\\\n",
    "0 & I\n",
    "\\end{matrix}\\right)\n",
    "-\n",
    "\\left(\\begin{matrix}\n",
    "0 & 0\\\\\n",
    "A_{II}^{-1}A_{IB} & I\n",
    "\\end{matrix}\\right)\n",
    "\\right]\n",
    " = \\\\ =\n",
    "\\left(\\begin{matrix}\n",
    "A_{BB} & A_{BI} \\\\\n",
    "A_{IB} & A_{II}\n",
    "\\end{matrix}\\right)\n",
    "\\cdot\n",
    "\\left(\\begin{matrix}\n",
    "I & 0\\\\\n",
    "-A_{II}^{-1}A_{IB} & 0\n",
    "\\end{matrix}\\right)\n",
    "&\n",
    "=\n",
    "\\left(\\begin{matrix}\n",
    "A_{BB} - A_{BI}A_{II}^{-1}A_{IB} & 0\\\\\n",
    "0 & 0\n",
    "\\end{matrix}\\right)\n",
    "\\end{align*}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%px\n",
    "innerdofs = BitArray([len(fes.ParallelDofs().Dof2Proc(k))==0 for k in range(fes.ndof)]) & fes.FreeDofs()\n",
    "A = a.mat.local_mat\n",
    "Aiinv = A.Inverse(innerdofs, inverse='sparsecholesky')\n",
    "Fhat = B @ A @ (IdentityMatrix() - Aiinv @ A) @ B.T"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Multiplicity Scaling\n",
    "\n",
    "The scaled Dirichlet-Preconditioner is just \n",
    "\n",
    "$$\n",
    "M_s^{-1} = B_s ~ (A_{BB} - A_{BI}A_{II}^{-1}A_{IB}) ~ B_s^T\n",
    "$$\n",
    "\n",
    "where \n",
    "\n",
    "$$\n",
    "B_s^T u = DB^Tu\n",
    "$$\n",
    "\n",
    "and $D = \\text{diag}([\\text{multiplicity of DOF }i\\text{ }: i=0\\ldots n-1])$.\n",
    "\n",
    "We can implement this without untroducing D as a matrix, and thus avoiding copying of entries."
   ]
  },
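  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A small worked case, assuming that every DOF touched by a Lagrange multiplier is shared by exactly two subdomains\n",
    "(in this 2D setting, DOFs shared by more ranks are primal and carry no multipliers):\n",
    "then the scaling acts as the factor $\\tfrac{1}{2}$ on all of these DOFs, and\n",
    "\n",
    "$$\n",
    "B_s^T = \\tfrac{1}{2} B^T\n",
    "\\qquad\\Rightarrow\\qquad\n",
    "M_s^{-1} = \\tfrac{1}{4}~B ~ (A_{BB} - A_{BI}A_{II}^{-1}A_{IB}) ~ B^T = \\tfrac{1}{4} M^{-1}\n",
    "$$\n",
    "\n",
    "A constant positive factor in the preconditioner leaves the CG iterates unchanged, so in this special case the scaling\n",
    "only rescales the error measure printed by the solver."
   ]
  },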
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%px\n",
    "class ScaledMat(BaseMatrix):\n",
    "    def __init__ (self, A, vals):\n",
    "        super(ScaledMat, self).__init__()\n",
    "        self.A = A\n",
    "        self.scaledofs = [k for k in enumerate(vals) if k[1]!=1.0]\n",
    "    def CreateRowVector(self):\n",
    "        return self.A.CreateRowVector()\n",
    "    def CreateColVector(self):\n",
    "        return self.A.CreateColVector()\n",
    "    def Mult(self, x, y):\n",
    "        y.data = self.A * x\n",
    "        for k in self.scaledofs:\n",
    "            y[k[0]] *= k[1]\n",
    "scaledA = ScaledMat(A, [1.0/(1+len(fes.ParallelDofs().Dof2Proc(k))) for k in range(fes.ndof)])\n",
    "scaledBT = ScaledMat(B.T, [1.0/(1+len(fes.ParallelDofs().Dof2Proc(k))) for k in range(fes.ndof)])\n",
    "Fhat2 = B @ scaledA @ (IdentityMatrix() - Aiinv @ A) @ scaledBT"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare RHS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%px\n",
    "hv = B.CreateRowVector()\n",
    "rhs = B.CreateColVector()\n",
    "lam = B.CreateColVector()\n",
    "hv.data = A_dp_inv * f.vec\n",
    "rhs.data = B * hv"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Solve for $\\lambda$ - without scaling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[stdout:10] \n",
      "it =  0  err =  0.33241357601342986\n",
      "it =  1  err =  0.08073724692259508\n",
      "it =  2  err =  0.03342603729269669\n",
      "it =  3  err =  0.018590695356589717\n",
      "it =  4  err =  0.0059947474548240785\n",
      "it =  5  err =  0.0027638282106529502\n",
      "it =  6  err =  0.0014339187849704818\n",
      "it =  7  err =  0.0005940409981120754\n",
      "it =  8  err =  0.0003125855268094198\n",
      "it =  9  err =  0.0001315943979595822\n",
      "it =  10  err =  5.087321086144511e-05\n",
      "it =  11  err =  2.259284681093829e-05\n",
      "it =  12  err =  8.338649672536656e-06\n",
      "it =  13  err =  3.884661119865966e-06\n",
      "it =  14  err =  1.5186884566911164e-06\n",
      "it =  15  err =  6.629027204434754e-07\n",
      "time to solve = 0.5055452309316024\n"
     ]
    }
   ],
   "source": [
    "%%px\n",
    "tsolve1 = -comm.WTime()\n",
    "solvers.CG(mat=F, pre=Fhat, rhs=rhs, sol=lam, maxsteps=100, printrates=comm.rank==0, tol=1e-6)\n",
    "tsolve1 += comm.WTime()\n",
    "if comm.rank==0:\n",
    "    print('time to solve =', tsolve1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Solve for $\\lambda$ - with scaling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[stdout:10] \n",
      "it =  0  err =  0.16620678800671493\n",
      "it =  1  err =  0.03922813647750846\n",
      "it =  2  err =  0.015375708271880666\n",
      "it =  3  err =  0.007954691011138264\n",
      "it =  4  err =  0.0028048597514098014\n",
      "it =  5  err =  0.001239627101399431\n",
      "it =  6  err =  0.0006206316326847184\n",
      "it =  7  err =  0.00026791930129006525\n",
      "it =  8  err =  0.00013500096491957006\n",
      "it =  9  err =  5.9146271053553835e-05\n",
      "it =  10  err =  2.3367293859660653e-05\n",
      "it =  11  err =  1.0170342441937255e-05\n",
      "it =  12  err =  3.857744385789025e-06\n",
      "it =  13  err =  1.7348456034200445e-06\n",
      "it =  14  err =  6.956273460897132e-07\n",
      "time to solve = 0.4004018809646368\n"
     ]
    }
   ],
   "source": [
    "%%px\n",
    "tsolve2 = -comm.WTime()\n",
    "solvers.MinRes(mat=F, pre=Fhat2, rhs=rhs, sol=lam, maxsteps=100, printrates=comm.rank==0, tol=1e-6)\n",
    "tsolve2 += comm.WTime()\n",
    "if comm.rank==0:\n",
    "    print('time to solve =', tsolve2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### A faster dual-primal Inverse\n",
    "\n",
    "We just inverted $A_{\\scriptscriptstyle DP}$ with MUMPS, and while this is\n",
    "faster than inverting $A$, in practice this is not scalable.\n",
    "\n",
    "MUMPS does not realize that it can eliminate all of the dual DOFs **locally**.\n",
    "\n",
    "We can implement a much faster, and more scalable, dual-primal\n",
    "inverse per hand. For that, we reduce the problem to the Schur-complement\n",
    "with resprect to the primal DOFs **using local sparse matrix factorization**, and\n",
    "then only invert the primal schur-complement $S_{\\pi\\pi}$ **globally**.\n",
    "\n",
    "In order to globally invert $S_{\\pi\\pi}$, we need its local sub-matrix as\n",
    "a sparse matrix, not as the expression\n",
    "\n",
    "         A @ (IdentityMatrix() - All_inv @ A)\n",
    "\n",
    "For that, we can just repeatedly multiply unit vectors."
   ]
  },
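  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The idea of recovering matrix entries from a matrix-free operator can be sketched in plain Python.\n",
    "This is only an illustration of the principle, not the implementation of `Op2SPM` from `dd_toolbox`, which is not shown in this notebook;\n",
    "the names `operator_to_dense` and `toy_op` are made up for this example.\n",
    "\n",
    "```python\n",
    "def operator_to_dense(apply_op, n):\n",
    "    # Assemble the n x n matrix of a linear operator from its action on unit vectors.\n",
    "    # apply_op maps a list of n floats to a list of n floats\n",
    "    # (a stand-in for the Mult method of a matrix-free operator).\n",
    "    columns = []\n",
    "    for j in range(n):\n",
    "        e_j = [0.0] * n\n",
    "        e_j[j] = 1.0                    # j-th unit vector\n",
    "        columns.append(apply_op(e_j))   # the result is the j-th column\n",
    "    # transpose the list of columns into a list of rows\n",
    "    return [[columns[j][i] for j in range(n)] for i in range(n)]\n",
    "\n",
    "# Toy operator: the action of [[2, -1], [-1, 2]] without storing the matrix.\n",
    "def toy_op(x):\n",
    "    return [2.0 * x[0] - x[1], -x[0] + 2.0 * x[1]]\n",
    "\n",
    "S = operator_to_dense(toy_op, 2)\n",
    "print(S)\n",
    "```\n",
    "\n",
    "Each probe costs one operator application, so this only pays off when few columns are needed,\n",
    "which is the case here because each rank holds only a handful of primal DOFs."
   ]
  },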
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%px\n",
    "from dd_toolbox import Op2SPM\n",
    "class LocGlobInverse(BaseMatrix):\n",
    "    def __init__ (self, Aglob, freedofs, invtype_loc='sparsecholesky',\\\n",
    "                  invtype_glob='masterinverse'):\n",
    "        super(LocGlobInverse, self).__init__()\n",
    "        self.Aglob = Aglob  \n",
    "        self.A = Aglob.local_mat  \n",
    "        pardofs = Aglob.col_pardofs \n",
    "        local_dofs = BitArray([len(pardofs.Dof2Proc(k))==0 for k in range(pardofs.ndoflocal)]) & freedofs\n",
    "        global_dofs = BitArray(~local_dofs & freedofs)\n",
    "        self.All_inv = self.A.Inverse(local_dofs, inverse=invtype_loc) \n",
    "        sp_loc_asmult = self.A @ (IdentityMatrix() - self.All_inv @ self.A)\n",
    "        sp_loc = Op2SPM(sp_loc_asmult, global_dofs, global_dofs)\n",
    "        sp = ParallelMatrix(sp_loc, pardofs)\n",
    "        self.Sp_inv = sp.Inverse(global_dofs, inverse=invtype_glob)\n",
    "        self.tv = self.Aglob.CreateRowVector()\n",
    "        self.btilde = self.Aglob.CreateRowVector()\n",
    "    def CreateRowVector(self):\n",
    "        return self.Aglob.CreateRowVector()\n",
    "    def CreateColVector(self):\n",
    "        return self.CreateRowVector()\n",
    "    def Mult(self, b, u):\n",
    "        self.tv.data = self.All_inv * b\n",
    "        b.Distribute()\n",
    "        self.btilde.data = b - self.A * self.tv\n",
    "        u.data = self.Sp_inv * self.btilde\n",
    "        self.tv.data =  b - self.A * u\n",
    "        u.data += self.All_inv * self.tv"
   ]
  },
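  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Ignoring the parallel bookkeeping, the `Mult` method above is the usual block elimination.\n",
    "Writing $l$ for the local DOFs and $\\pi$ for the remaining (primal) DOFs, where the index names are chosen here for illustration,\n",
    "the primal Schur complement is\n",
    "\n",
    "$$\n",
    "S_{\\pi\\pi} = A_{\\pi\\pi} - A_{\\pi l}A_{ll}^{-1}A_{l\\pi}\n",
    "$$\n",
    "\n",
    "and $A u = b$ is solved in three steps:\n",
    "\n",
    "\\begin{align*}\n",
    "\\tilde b_{\\pi} &= b_{\\pi} - A_{\\pi l}A_{ll}^{-1}b_{l} \\\\\n",
    "u_{\\pi} &= S_{\\pi\\pi}^{-1}\\tilde b_{\\pi} \\\\\n",
    "u_{l} &= A_{ll}^{-1}(b_{l} - A_{l\\pi}u_{\\pi})\n",
    "\\end{align*}\n",
    "\n",
    "Only the middle step needs a global solve; the other two use the local factorization of $A_{ll}$."
   ]
  },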
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%px\n",
    "A_dp_inv2 = LocGlobInverse(A_dp, fes.FreeDofs(), \\\n",
    "                           'sparsecholesky', 'mumps')\n",
    "F2 = B @ A_dp_inv2 @ B.T"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Solve for $\\lambda$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[stdout:10] \n",
      "it =  0  err =  0.33241357601342986\n",
      "it =  1  err =  0.08073724692259508\n",
      "it =  2  err =  0.0334260372926967\n",
      "it =  3  err =  0.018590695356589713\n",
      "it =  4  err =  0.0059947474548240785\n",
      "it =  5  err =  0.0027638282106529494\n",
      "it =  6  err =  0.0014339187849704807\n",
      "it =  7  err =  0.0005940409981120754\n",
      "it =  8  err =  0.00031258552680941964\n",
      "it =  9  err =  0.0001315943979595821\n",
      "it =  10  err =  5.087321086144511e-05\n",
      "it =  11  err =  2.2592846810938306e-05\n",
      "it =  12  err =  8.33864967253665e-06\n",
      "it =  13  err =  3.884661119865963e-06\n",
      "it =  14  err =  1.5186884566911173e-06\n",
      "it =  15  err =  6.629027204434756e-07\n",
      "\n",
      "it =  0  err =  0.16620678800671493\n",
      "it =  1  err =  0.04036862346129754\n",
      "it =  2  err =  0.01671301864634835\n",
      "it =  3  err =  0.009295347678294857\n",
      "it =  4  err =  0.002997373727412039\n",
      "it =  5  err =  0.0013819141053264743\n",
      "it =  6  err =  0.0007169593924852402\n",
      "it =  7  err =  0.0002970204990560376\n",
      "it =  8  err =  0.00015629276340470955\n",
      "it =  9  err =  6.5797198979791e-05\n",
      "it =  10  err =  2.5436605430722535e-05\n",
      "it =  11  err =  1.1296423405469144e-05\n",
      "it =  12  err =  4.169324836268321e-06\n",
      "it =  13  err =  1.94233055993298e-06\n",
      "it =  14  err =  7.59344228345558e-07\n",
      "it =  15  err =  3.3145136022173774e-07\n",
      "\n",
      "it =  0  err =  0.16620678800671493\n",
      "it =  1  err =  0.040368623461284836\n",
      "it =  2  err =  0.016713018646346594\n",
      "it =  3  err =  0.009295347678294864\n",
      "it =  4  err =  0.002997373727412205\n",
      "it =  5  err =  0.0013819141053272297\n",
      "it =  6  err =  0.0007169593924864385\n",
      "it =  7  err =  0.0002970204990560735\n",
      "it =  8  err =  0.00015629276340464927\n",
      "it =  9  err =  6.579719897977235e-05\n",
      "it =  10  err =  2.54366054307179e-05\n",
      "it =  11  err =  1.1296423405458421e-05\n",
      "it =  12  err =  4.169324836254944e-06\n",
      "it =  13  err =  1.942330559925504e-06\n",
      "it =  14  err =  7.593442283447208e-07\n",
      "it =  15  err =  3.3145136022149025e-07\n",
      "\n",
      "time solve v1:  0.5127908659633249\n",
      "time solve v2:  0.47894827101845294\n",
      "time solve v3:  0.14496840198989958\n"
     ]
    }
   ],
   "source": [
    "%%px\n",
    "\n",
    "tsolve1 = -comm.WTime()\n",
    "solvers.CG(mat=F, pre=Fhat, rhs=rhs, sol=lam, maxsteps=100, printrates=comm.rank==0, tol=1e-6)\n",
    "tsolve1 += comm.WTime()\n",
    "\n",
    "if comm.rank==0:\n",
    "    print('')\n",
    "    \n",
    "tsolve2 = -comm.WTime()\n",
    "solvers.CG(mat=F, pre=Fhat2, rhs=rhs, sol=lam, maxsteps=100, printrates=comm.rank==0, tol=1e-6)\n",
    "tsolve2 += comm.WTime()\n",
    "\n",
    "if comm.rank==0:\n",
    "    print('')\n",
    "\n",
    "tsolve3 = -comm.WTime()\n",
    "solvers.CG(mat=F2, pre=Fhat2, rhs=rhs, sol=lam, maxsteps=100, printrates=comm.rank==0, tol=1e-6)\n",
    "tsolve3 += comm.WTime()\n",
    "\n",
    "if comm.rank==0:\n",
    "    print('\\ntime solve v1: ', tsolve1)\n",
    "    print('time solve v2: ', tsolve2)\n",
    "    print('time solve v3: ', tsolve3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Not that the normal and scaled variants of the Dirichlet-Preconditioner\n",
    "are equivalent here, because all DOFs are either local, primal, or shared\n",
    "by **exactly two** ranks.\n",
    "\n",
    "\n",
    "\n",
    "### Reconstruct $u$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[stdout:10] norm jump u:  2.7631462978367745e-07\n"
     ]
    }
   ],
   "source": [
    "%%px\n",
    "\n",
    "gfu = GridFunction(fes)\n",
    "hv.data = f.vec - B.T * lam\n",
    "gfu.vec.data = A_dp_inv * hv\n",
    "\n",
    "jump = lam.CreateVector()\n",
    "jump.data = B * gfu.vec\n",
    "norm_jump = Norm(jump)\n",
    "if comm.rank==0:\n",
    "    print('norm jump u: ', norm_jump)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mumps Inverse :   2.067965030670166\n",
      "Mumps Inverse - analysis :   1.8604109287261963\n",
      "Parallelmumps mult inverse :   1.8263428211212158\n"
     ]
    }
   ],
   "source": [
    "%%px --target 1\n",
    "for t in sorted(filter(lambda t:t['time']>0.5, Timers()), key=lambda t:t['time'], reverse=True):\n",
    "    print(t['name'], ':  ', t['time'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[stdout:36] \n",
      "Mumps Inverse :   2.035609245300293\n",
      "Mumps Inverse - analysis :   1.8604459762573242\n",
      "Parallelmumps mult inverse :   1.8067901134490967\n",
      "SparseCholesky<d,d,d>::MultAdd :   0.1947925090789795\n",
      "Mumps Inverse - factor :   0.1496748924255371\n",
      "SparseCholesky<d,d,d>::MultAdd fac1 :   0.12263321876525879\n",
      "SparseCholesky<d,d,d>::MultAdd fac2 :   0.06641483306884766\n"
     ]
    }
   ],
   "source": [
    "%%px\n",
    "t_chol = filter(lambda t: t['name'] == 'SparseCholesky<d,d,d>::MultAdd', Timers()).__next__()\n",
    "maxt = comm.Max(t_chol['time']) \n",
    "if t_chol['time'] == maxt:\n",
    "    for t in sorted(filter(lambda t:t['time']>0.3*maxt, Timers()), key=lambda t:t['time'], reverse=True):\n",
    "        print(t['name'], ':  ', t['time'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "stop_cluster()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Experiment Below!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Waiting for connection file: ~/.ipython/profile_ngsolve/security/ipcontroller-kogler-client.json\n",
      "connecting ... try:6 succeeded!"
     ]
    }
   ],
   "source": [
    "from usrmeeting_jupyterstuff import *\n",
    "num_procs = '80'\n",
    "start_cluster(num_procs)\n",
    "connect_cluster()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%px\n",
    "from ngsolve import *\n",
    "import netgen.meshing as ngmeshing\n",
    "from ngsolve.la import ParallelMatrix, FETI_Jump\n",
    "from dd_toolbox import LocGlobInverse, ScaledMat\n",
    "\n",
    "def load_mesh(nref=0):\n",
    "    ngmesh = ngmeshing.Mesh(dim=2)\n",
    "    ngmesh.Load('squaref.vol')\n",
    "    for l in range(nref):\n",
    "        ngmesh.Refine()\n",
    "    return Mesh(ngmesh)\n",
    "\n",
    "def setup_space(mesh, order=1):\n",
    "    comm = MPI_Init()\n",
    "    fes = H1(mesh, order=order, dirichlet='right|top')\n",
    "    a = BilinearForm(fes)\n",
    "    u,v = fes.TnT()\n",
    "    a += SymbolicBFI(grad(u)*grad(v))\n",
    "    a.Assemble()\n",
    "    f = LinearForm(fes)\n",
    "    f += SymbolicLFI(x*y*v)\n",
    "    f.Assemble()\n",
    "    avg_dof = comm.Sum(fes.ndof) / comm.size\n",
    "    if comm.rank==0:\n",
    "        print('global,  ndof =', fes.ndofglobal, ', lodofs =', fes.lospace.ndofglobal)\n",
    "        print('avg DOFs per core: ', avg_dof)\n",
    "    return [fes, a, f]\n",
    "\n",
    "def setup_FETIDP(fes, a):\n",
    "    primal_dofs = BitArray([len(fes.ParallelDofs().Dof2Proc(k))>1 for k in range(fes.ndof)]) & fes.FreeDofs()\n",
    "    nprim = comm.Sum(sum([1 for k in range(fes.ndof) if primal_dofs[k] \\\n",
    "                          and comm.rank<fes.ParallelDofs().Dof2Proc(k)[0] ]))\n",
    "    if comm.rank==0:\n",
    "        print('# primal dofs global: ', nprim)  \n",
    "    dp_pardofs = fes.ParallelDofs().SubSet(primal_dofs)\n",
    "    A_dp = ParallelMatrix(a.mat.local_mat, dp_pardofs)\n",
    "    A_dp_inv = A_dp.Inverse(fes.FreeDofs(), inverse='mumps')  \n",
    "    dual_pardofs = fes.ParallelDofs().SubSet(BitArray(~primal_dofs & fes.FreeDofs()))\n",
    "    B = FETI_Jump(dual_pardofs, u_pardofs=dp_pardofs)\n",
    "    if comm.rank==0:\n",
    "        print('# of global multipliers = :', B.col_pardofs.ndofglobal)\n",
    "    F = B @ A_dp_inv @ B.T\n",
    "    A_dp_inv2 = LocGlobInverse(A_dp, fes.FreeDofs(), \\\n",
    "                               invtype_loc='sparsecholesky',\\\n",
    "                               invtype_glob='masterinverse')\n",
    "    F2 = B @ A_dp_inv2 @ B.T\n",
    "    innerdofs = BitArray([len(fes.ParallelDofs().Dof2Proc(k))==0 for k in range(fes.ndof)]) & fes.FreeDofs()\n",
    "    A = a.mat.local_mat\n",
    "    Aiinv = A.Inverse(innerdofs, inverse='sparsecholesky')\n",
    "    Fhat = B @ A @ (IdentityMatrix() - Aiinv @ A) @ B.T\n",
    "    scaledA = ScaledMat(A, [1.0/(1+len(fes.ParallelDofs().Dof2Proc(k))) for k in range(fes.ndof)])\n",
    "    scaledBT = ScaledMat(B.T, [1.0/(1+len(fes.ParallelDofs().Dof2Proc(k))) for k in range(fes.ndof)])\n",
    "    Fhat2 = B @ scaledA @ (IdentityMatrix() - Aiinv @ A) @ scaledBT\n",
    "    return [A_dp, A_dp_inv, A_dp_inv2, F, F2, Fhat, Fhat2, B, scaledA, scaledBT]\n",
    "    \n",
    "def prep(B, Ainv, f):\n",
    "    rhs.data = (B @ Ainv) * f.vec\n",
    "    return rhs\n",
    "\n",
    "def solve(mat, pre, rhs, sol):\n",
    "    t = comm.WTime()\n",
    "    solvers.CG(mat=mat, pre=pre, rhs=rhs, sol=sol, \\\n",
    "               maxsteps=100, printrates=comm.rank==0, tol=1e-6)\n",
    "    return comm.WTime() - t\n",
    "    \n",
    "def post(B, Ainv, gfu, lam):\n",
    "    hv = B.CreateRowVector()\n",
    "    hv.data = f.vec - B.T * lam\n",
    "    gfu.vec.data = Ainv * hv\n",
    "    jump = lam.CreateVector()\n",
    "    jump.data = B * gfu.vec\n",
    "    norm_jump = Norm(jump)\n",
    "    if comm.rank==0:\n",
    "        print('norm jump u: ', norm_jump)   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[stdout:1] \n",
      "global,  ndof = 187833 , lodofs = 47161\n",
      "avg DOFs per core:  2441.0875\n",
      "# primal dofs global:  125\n",
      "# of global multipliers = : 7188\n",
      "\n",
      "it =  0  err =  0.2775003436443844\n",
      "it =  1  err =  0.059395434678847045\n",
      "it =  2  err =  0.02668466647171506\n",
      "it =  3  err =  0.013658995259387977\n",
      "it =  4  err =  0.0049131733239784845\n",
      "it =  5  err =  0.0021229848364316215\n",
      "it =  6  err =  0.0009537456623955999\n",
      "it =  7  err =  0.0003609947384913908\n",
      "it =  8  err =  0.0001589874932742381\n",
      "it =  9  err =  6.398790188681329e-05\n",
      "it =  10  err =  2.5454548026516644e-05\n",
      "it =  11  err =  9.979151075046508e-06\n",
      "it =  12  err =  4.614085390943377e-06\n",
      "it =  13  err =  1.8339700212425154e-06\n",
      "it =  14  err =  7.44343568078797e-07\n",
      "it =  15  err =  3.473053788503813e-07\n",
      "\n",
      "it =  0  err =  0.1387501718221922\n",
      "it =  1  err =  0.029697717339423522\n",
      "it =  2  err =  0.01334233323585753\n",
      "it =  3  err =  0.006829497629693989\n",
      "it =  4  err =  0.002456586661989242\n",
      "it =  5  err =  0.0010614924182158108\n",
      "it =  6  err =  0.00047687283119779986\n",
      "it =  7  err =  0.00018049736924569538\n",
      "it =  8  err =  7.949374663711905e-05\n",
      "it =  9  err =  3.1993950943406623e-05\n",
      "it =  10  err =  1.2727274013258324e-05\n",
      "it =  11  err =  4.989575537523253e-06\n",
      "it =  12  err =  2.3070426954716886e-06\n",
      "it =  13  err =  9.169850106212577e-07\n",
      "it =  14  err =  3.721717840393986e-07\n",
      "it =  15  err =  1.7365268942519067e-07\n",
      "\n",
      "it =  0  err =  0.1387501718221922\n",
      "it =  1  err =  0.02969771733942429\n",
      "it =  2  err =  0.013342333235858251\n",
      "it =  3  err =  0.006829497629693326\n",
      "it =  4  err =  0.002456586661989025\n",
      "it =  5  err =  0.0010614924182156267\n",
      "it =  6  err =  0.00047687283119771643\n",
      "it =  7  err =  0.00018049736924568807\n",
      "it =  8  err =  7.949374663711586e-05\n",
      "it =  9  err =  3.19939509434018e-05\n",
      "it =  10  err =  1.2727274013256248e-05\n",
      "it =  11  err =  4.98957553752233e-06\n",
      "it =  12  err =  2.3070426954715653e-06\n",
      "it =  13  err =  9.16985010621377e-07\n",
      "it =  14  err =  3.7217178403948777e-07\n",
      "it =  15  err =  1.736526894252284e-07\n",
      "\n",
      "time solve v1:  0.4511956589994952\n",
      "time solve v2:  0.48503338207956403\n",
      "time solve v3:  0.06109175400342792\n",
      "norm jump u:  1.0723427980831504e-07\n"
     ]
    }
   ],
   "source": [
    "%%px\n",
    "# Set up the mesh, the FE space and the FETI-DP operators on every rank.\n",
    "comm = MPI_Init()\n",
    "mesh = load_mesh(nref=1)\n",
    "fes, a, f = setup_space(mesh, order=2)\n",
    "A_dp, A_dp_inv, A_dp_inv2, F, F2, Fhat, Fhat2, B, scaledA, scaledBT = setup_FETIDP(fes, a)\n",
    "rhs = B.CreateColVector()\n",
    "lam = B.CreateColVector()\n",
    "prep(B, A_dp_inv2, f)\n",
    "# Time solve() for three variants of its first two arguments:\n",
    "#   v1: (F, Fhat),  v2: (F, Fhat2),  v3: (F2, Fhat2).\n",
    "# The empty prints on rank 0 only separate the iteration logs.\n",
    "if comm.rank == 0:\n",
    "    print('')\n",
    "t1 = solve(F,  Fhat,  rhs, lam)\n",
    "if comm.rank == 0:\n",
    "    print('')\n",
    "t2 = solve(F,  Fhat2, rhs, lam)\n",
    "if comm.rank == 0:\n",
    "    print('')\n",
    "t3 = solve(F2, Fhat2, rhs, lam)\n",
    "if comm.rank == 0:\n",
    "    print('\\ntime solve v1: ', t1)\n",
    "    print('time solve v2: ', t2)\n",
    "    print('time solve v3: ', t3)\n",
    "# Recover the solution from the multipliers of the last solve.\n",
    "gfu = GridFunction(fes)\n",
    "post(B, A_dp_inv2, gfu, lam)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mumps Inverse :   2.0638580322265625\n",
      "Mumps Inverse - analysis :   1.8557240962982178\n",
      "Parallelmumps mult inverse :   0.8613486289978027\n"
     ]
    }
   ],
   "source": [
    "%%px --target 1\n",
    "# On engine 1 only: list every timer above 0.5 s, slowest first.\n",
    "for t in sorted(filter(lambda t: t['time'] > 0.5, Timers()), key=lambda t: t['time'], reverse=True):\n",
    "    print(t['name'], ':  ', t['time'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[stdout:12] \n",
      "timers from rank  14 :\n",
      "Mumps Inverse :   2.063915967941284\n",
      "Mumps Inverse - analysis :   1.855010986328125\n",
      "Parallelmumps mult inverse :   0.8388888835906982\n",
      "dummy - AllReduce :   0.26943492889404297\n",
      "Mumps Inverse - factor :   0.19652891159057617\n",
      "dummy - AllReduce :   0.08971333503723145\n",
      "SparseCholesky<d,d,d>::MultAdd :   0.059799909591674805\n",
      "SparseCholesky<d,d,d>::MultAdd fac1 :   0.03657698631286621\n",
      "SparseCholesky<d,d,d>::MultAdd fac2 :   0.021203994750976562\n",
      "SparseCholesky - total :   0.019064903259277344\n"
     ]
    }
   ],
   "source": [
    "%%px\n",
    "# Print the dominant timers from the rank whose SparseCholesky MultAdd took longest.\n",
    "t_chol = next(filter(lambda t: t['name'] == 'SparseCholesky<d,d,d>::MultAdd', Timers()))\n",
    "maxt = comm.Max(t_chol['time'])\n",
    "if t_chol['time'] == maxt:\n",
    "    print('timers from rank ', comm.rank, ':')\n",
    "    # Threshold: 30% of the slowest MultAdd time, capped at 0.5 s.\n",
    "    for t in sorted(filter(lambda t: t['time'] > min(0.3*maxt, 0.5), Timers()), key=lambda t: t['time'], reverse=True):\n",
    "        print(t['name'], ':  ', t['time'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "stop_cluster()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
